import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
import argparse
import torch.nn.functional  as F
import shap
import random

import os
# Presumably works around the Intel OpenMP "libiomp5 already initialized" abort that
# occurs when more than one bundled OpenMP runtime gets loaded into the process.
# NOTE(review): this only disables the safety check; confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# General plot palette (not referenced elsewhere in this file).
col = ['dodgerblue', "tab:orange", "mediumaquamarine", 'lightcoral', 'skyblue', 'sandybrown']

color_noise = 'tab:blue'  # not referenced elsewhere in this file
color_feature = 'lightcoral'  # not referenced elsewhere in this file
color_loss = 'tab:blue'  # loss curves drawn by visualize_all
color_acc_train = 'tab:blue'  # train-accuracy curve drawn by visualize_all
color_acc_test = 'lightcoral'  # test-accuracy curve drawn by visualize_all

# Custom transformation: convert to tensor, resize to 27x27, normalize, and split the
# image into vertical patches.
class NormalizeAndSegmentVerticallyTransform:
    """Turn a single-channel image into ``num_patches`` flattened vertical strips.

    Pipeline: ``ToTensor`` (pixel values scaled to ``[0, 1]``) -> resize to 27x27 ->
    normalize with mean 0.5 / std 0.5 (``[0, 1]`` -> ``[-1, 1]``) -> split the columns
    into ``num_patches`` equal-width strips -> flatten each strip row by row.

    With the default ``num_patches=3`` every image becomes a ``[3, 243]`` float32
    tensor (three 27x9 strips, left to right), the layout ``reverse_transform`` undoes.

    Raises:
        ValueError: from ``__call__`` when 27 is not divisible by ``num_patches``.
    """

    def __init__(self, num_patches=3):
        # Built once here instead of on every call.
        self.to_tensor = transforms.ToTensor()
        self.resize = transforms.Resize((27, 27))
        # Single-channel normalization for grayscale input.
        self.normalize = transforms.Normalize((0.5,), (0.5,))
        self.num_patches = num_patches

    def __call__(self, img):
        img = self.normalize(self.resize(self.to_tensor(img)))
        img = img.squeeze(0)  # [1, H, W] -> [H, W] for single-channel input
        width = img.shape[1]
        if width % self.num_patches:
            raise ValueError(
                f"image width {width} cannot be split into {self.num_patches} equal patches")
        # Split the columns directly in torch (no NumPy round trip); stacking the
        # [H, W / num_patches] strips and flattening each keeps row-major order.
        strips = torch.split(img, width // self.num_patches, dim=1)
        return torch.stack(strips).reshape(self.num_patches, -1)


# Filter function to get indices of labels 0 and 1
# def filter_classes(dataset, classes=(0, 1), sample=100):
#     targets = dataset.targets
#
#     indices = (targets == classes[0]) | (targets == classes[1])
#     return torch.nonzero(indices).squeeze()

def filter_classes(dataset, classes=(0, 1), samples_per_class=100):
    """Return dataset indices of the first ``samples_per_class`` examples of each class.

    Indices are grouped by class in the order of ``classes`` (all of ``classes[0]``
    first, then ``classes[1]``, ...) and keep dataset order within a class. A class
    with fewer examples contributes all of them; an absent class contributes none.

    Args:
        dataset: object with a ``targets`` attribute (a tensor or a sequence of labels).
        classes: labels to keep.
        samples_per_class: maximum number of indices taken per class.

    Returns:
        1-D ``torch.long`` tensor of indices, suitable for ``torch.utils.data.Subset``.
    """
    # as_tensor accepts both tensor-valued and list-valued ``targets``.
    targets = torch.as_tensor(dataset.targets)
    # as_tuple=True always yields a 1-D index tensor; ``nonzero(...).squeeze()``
    # collapsed a single match into a 0-d tensor that cannot be sliced or concatenated.
    class_indices = [
        torch.nonzero(targets == class_label, as_tuple=True)[0][:samples_per_class]
        for class_label in classes
    ]
    if not class_indices:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(class_indices)


def reverse_transform(image, mean=0.5, std=0.5, num_patches=3, height=27):
    """Reassemble flattened vertical patches into images and undo the normalization.

    Inverse of the patch layout produced by ``NormalizeAndSegmentVerticallyTransform``:
    each patch row is a flattened ``[height, width]`` strip; strips are placed side by
    side (left to right) and the pixels are mapped back with ``x * std + mean``.

    Args:
        image: tensor of shape ``[num_patches, height * width]`` or
            ``[batch, num_patches, height * width]`` (``[3, 243]`` / ``[B, 3, 243]``
            with the defaults). Need not be contiguous.
        mean, std: normalization constants to undo.
        num_patches: number of strips per image.
        height: pixel rows per strip.

    Returns:
        Tensor of shape ``[batch, height, num_patches * width]``; ``batch == 1`` for
        unbatched input.
    """
    if image.dim() == 2:
        image = image.unsqueeze(0)
    batch = image.shape[0]
    # [B, P, H*W] -> [B, P, H, W] -> [B, H, P, W] -> [B, H, P*W]: strips side by side.
    # reshape (not view) so non-contiguous inputs such as transposed tensors work too.
    strips = image.reshape(batch, num_patches, height, -1)
    return strips.permute(0, 2, 1, 3).reshape(batch, height, -1) * std + mean

class CNN_diff(nn.Module):
    """Two-layer network with squared activations, used as the diffusion noise predictor.

    For every patch ``x_n`` the output is ``sum_r (w_r . x_n)^2 * w_r / sqrt(m)`` over
    the ``m`` rows ``w_r`` of ``W``, so the output has the same shape as the input.

    Args:
        d: patch dimension.
        m: number of neurons (rows of ``W``).
        num_patches: stored on the module; ``forward`` does not use it.
    """

    def __init__(self, d, m, num_patches=3):
        super().__init__()
        self.m = m
        self.num_patches = num_patches
        self.W = nn.Parameter(torch.randn(m, d) * 0.1)

    def forward(self, x):
        # x: [..., num_patches, d] -> squared neuron responses [..., num_patches, m]
        activations = torch.matmul(x, self.W.t()).pow(2)
        # Map the responses back onto the neuron directions, scaled by 1/sqrt(m).
        return torch.matmul(activations, self.W) / np.sqrt(self.m)


class CNN_class(nn.Module):
    """Two-layer binary classifier with squared activations and two filter banks.

    ``f(x) = sum_n mean_r (wp_r . x_n)^2 - sum_n mean_r (wn_r . x_n)^2``, where ``n``
    runs over patches and ``r`` over the ``m`` neurons of ``Wp`` / ``Wn``. The
    output has one value per example (the patch and neuron axes are reduced).

    Args:
        d: patch dimension.
        m: neurons per filter bank.
        num_patches: stored on the module; ``forward`` does not use it.
    """

    def __init__(self, d, m, num_patches=3):
        super().__init__()
        self.m = m
        self.num_patches = num_patches
        self.Wp = nn.Parameter(torch.randn(m, d) * 0.1)
        self.Wn = nn.Parameter(torch.randn(m, d) * 0.1)

    def _bank_score(self, x, weights):
        """Squared responses to one filter bank, averaged over neurons and summed over patches."""
        return torch.matmul(x, weights.t()).pow(2).mean(dim=-1).sum(dim=-1)

    def forward(self, x):
        # x: [..., num_patches, d]
        return self._bank_score(x, self.Wp) - self._bank_score(x, self.Wn)



# Example usage


# Dummy input of shape [batch_size, num_patches, input_dim]
# x = torch.randn(32, 3, d)
#
# # Forward pass
# output = model(x)
# print(output.shape)




def calculate_accuracy(loader, model, device=None):
    """Return the fraction of examples in ``loader`` that ``model`` classifies correctly.

    The model must output one real-valued score per example (same shape as the
    labels). A score ``> 0`` is read as class 1 and anything else as class 0.

    Args:
        loader: iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
        model: module producing the scores. It is switched to eval mode and left there.
        device: device each batch is moved to. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters. This replaces the
            former implicit dependency on a module-level ``device`` global.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``loader`` yields no examples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            pred_binary = (model(images) > 0).float()
            total += labels.size(0)
            correct += (pred_binary == labels).sum().item()

    return correct / total if total else 0.0



class WrappedModel(nn.Module):
    """Expose a single-logit binary classifier as a two-column probability model.

    ``forward`` returns a tensor with a trailing axis of size 2: column 0 holds
    ``1 - sigmoid(logit)`` and column 1 holds ``sigmoid(logit)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        positive = torch.sigmoid(self.model(x))
        return torch.stack((1 - positive, positive), dim=-1)



# def visualize_classification(test_dataset):
#     checkpoint = torch.load('class_model_checkpoint.pth')
#     model = CNN_class(d=d, m=m)
#     model.load_state_dict(checkpoint['model_state_dict'])
#
#     model.eval()
#     input_image, label = test_dataset[0]
#     label = label*2 -1
#
#     input_image.requires_grad = True
#
#     output = model(input_image)
#     # loss = torch.log(torch.add(torch.exp(-output * label), 1))
#     output.backward()
#     gradients = input_image.grad
#     # heatmap = np.abs(gradients)
#     # grad_cam = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
#
#     ori_img = reverse_transform(input_image)
#     grad_cam = (reverse_transform(gradients) -0.5)/0.5
#
#     plt.imshow(ori_img.detach().numpy().squeeze(), cmap='gray')
#     plt.show()
#
#     plt.imshow(grad_cam.detach().numpy().squeeze(), cmap='jet', alpha=0.5)  # Overlay heatmap
#     plt.colorbar()
#     plt.show()



def _save_figure(path):
    """Save the current figure to ``path`` at 300 dpi (creating its directory), then close it."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    plt.savefig(path, dpi=300)
    plt.close()


def _plot_curves(curves, path, ylabel=None):
    """Plot ``(values, label, color)`` series against iteration and save the figure to ``path``.

    A legend is drawn only when at least one series has a label.
    """
    plt.subplots(figsize=(7, 6))
    for values, label, color in curves:
        plt.plot(values, label=label, color=color, linewidth=3)
    plt.xlabel('Iteration', fontsize=25)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=25)
    plt.tick_params(axis='both', which='major', labelsize=15)
    plt.tick_params(axis='both', which='minor', labelsize=15)
    if any(label is not None for _, label, _ in curves):
        plt.legend(fontsize=22)
    plt.tight_layout()
    _save_figure(path)


def visualize_all(test_dataset, indices):
    """Render the weight-visualization grid and the training curves into ``figures/``.

    Loads ``class_model_checkpoint.pth`` and ``diff_model_checkpoint.pth`` and
    rebuilds both models with the module-level ``d`` and ``m``. For each index in
    ``indices``, one grid column shows the input digit, ``visualize_classification``
    and ``visualize_diffusion`` (saved as ``figures/diff_class_mnist.png``). Also
    saves the classifier / diffusion loss curves and the classifier accuracy curves.
    The ``figures`` directory is created if missing.
    """
    indices = list(indices)

    # map_location='cpu': GPU-trained checkpoints must load on CPU-only machines, and
    # the models are applied to CPU tensors taken straight from the dataset.
    checkpoint_class = torch.load('class_model_checkpoint.pth', map_location='cpu')
    model_class = CNN_class(d=d, m=m)
    model_class.load_state_dict(checkpoint_class['model_state_dict'])

    checkpoint_diff = torch.load('diff_model_checkpoint.pth', map_location='cpu')
    model_diff = CNN_diff(d=d, m=m)
    model_diff.load_state_dict(checkpoint_diff['model_state_dict'])

    # One column per requested index; squeeze=False keeps ``axes`` 2-D even for a
    # single column.
    num_cols = max(len(indices), 1)
    fig, axes = plt.subplots(3, num_cols, figsize=(10, 6), squeeze=False)

    for j, idx in enumerate(indices):
        images, _ = test_dataset[idx]
        original_image = reverse_transform(images).detach().cpu().numpy()

        class_visual = visualize_classification(model_class, images)
        diff_visual = visualize_diffusion(model_diff, images)

        axes[0, j].imshow(original_image[0], cmap='gray')
        axes[1, j].imshow(class_visual[0], cmap='BuPu')
        axes[2, j].imshow(diff_visual[0], cmap='BuPu')
        for row in range(3):
            axes[row, j].axis('off')

    fig.text(0.02, 0.84, 'Input', va='center', ha='center', rotation='vertical', fontsize=20)
    fig.text(0.02, 0.5, 'Classification', va='center', ha='center', rotation='vertical', fontsize=20)
    fig.text(0.02, 0.17, 'Diffusion', va='center', ha='center', rotation='vertical', fontsize=20)

    # Left margin leaves room for the row labels above.
    plt.subplots_adjust(left=0.05, right=1, top=1, bottom=0.0, wspace=0.08, hspace=0.08)
    _save_figure('figures/diff_class_mnist.png')

    # Loss curves (no y-label).
    _plot_curves([(checkpoint_class['loss'], None, color_loss)], 'figures/class_loss_mnist.png')
    _plot_curves([(checkpoint_diff['loss'], None, color_loss)], 'figures/diff_loss_mnist.png')

    # Train / test accuracy of the classifier.
    _plot_curves(
        [
            (checkpoint_class['train_acc'], 'Train acc', color_acc_train),
            (checkpoint_class['test_acc'], 'Test acc', color_acc_test),
        ],
        'figures/class_acc_mnist.png',
        ylabel='ACC',
    )


def visualize_classification(model, images):
    """Show, per patch, the classifier neuron that responds most strongly to it.

    For every patch, the neuron across both filter banks (``model.Wp`` and
    ``model.Wn``) with the largest absolute inner product with the patch is
    selected. The selected weight vectors are then rearranged into an image by
    ``reverse_transform``, without any intensity shift, so raw weight values are returned.

    Args:
        model: object exposing ``Wp`` and ``Wn`` of shape ``[m, d]`` (e.g. ``CNN_class``).
        images: patches of shape ``[num_patches, d]`` or ``[batch, num_patches, d]``.

    Returns:
        NumPy array laid out by ``reverse_transform`` (``[1, 27, 27]`` for a single
        ``[3, 243]`` input).
    """
    with torch.no_grad():
        # Dividing by the patch norm scales every score of a patch by the same positive
        # factor, so it does not change which neuron (or which bank) is selected.
        patch_norm = images.norm(dim=-1, keepdim=True)
        score_p = (torch.einsum('...nd,md->...nm', images, model.Wp) / patch_norm).abs()
        score_n = (torch.einsum('...nd,md->...nm', images, model.Wn) / patch_norm).abs()

        best_p, idx_p = score_p.max(dim=-1)
        best_n, idx_n = score_n.max(dim=-1)

        # 0 -> positive bank wins, 1 -> negative bank wins (a tie picks the first, i.e. Wp).
        winner = torch.stack((best_p, best_n)).max(dim=0).indices
        # Vectorized per-patch choice; works for batched and unbatched input.
        selected = torch.where((winner == 0).unsqueeze(-1), model.Wp[idx_p], model.Wn[idx_n])

        # mean=0, std=1: only rearrange the strips, keeping the raw weight values.
        return reverse_transform(selected, mean=0.0, std=1.0).cpu().numpy()

    # plt.imshow(original_image[0], cmap='gray')
    # plt.show()
    #
    # plt.imshow(denorm_image[0], cmap='GnBu')
    # plt.colorbar()
    # plt.show()

    # wrap_model = WrappedModel(model)

    # # select a set of background examples to take an expectation over
    # background = train_images #x_train[np.random.choice(x_train.shape[0], 200, replace=False)]
    #
    # # explain predictions of the model on three images
    # e = shap.DeepExplainer(wrap_model, background)
    # # ...or pass tensors directly
    # # e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
    #
    # inspect_images = test_images[-5:]
    # inspect_labels = test_labels[-5:]
    #
    # shap_values = e.shap_values(inspect_images, check_additivity=False)
    # idx = inspect_labels.unsqueeze(-1).unsqueeze(-1).expand(shap_values.shape[:-1])
    # result = torch.gather(torch.tensor(shap_values), -1, idx.unsqueeze(-1)).squeeze(-1)
    #
    # original_image = reverse_transform(inspect_images).detach().cpu().numpy()
    # shap_values_plot = reverse_transform(result).detach().cpu().numpy()
    # shap_values_plot = (shap_values_plot - 0.5)/0.5
    #
    # plt.imshow(original_image[0])
    # plt.show()
    #
    # plt.imshow(shap_values_plot[0], cmap='viridis')
    # plt.colorbar()
    # plt.show()

    # shap.image_plot(shap_values_plot[0], -original_image[0])

    # shap.image_plot(shap_values, -test_images[0:5])




    # # Zero all gradients before backward pass
    # model.zero_grad()
    #
    # # Forward pass
    # output = model(images)
    # output.backward()
    # gradients = images.grad
    #
    #

    # print(labels)
    # plt.imshow(denorm_image[0], cmap='gray')
    # plt.show()


def visualize_diffusion(model, images):
    """Show, per patch, the diffusion-model neuron that responds most strongly to it.

    For every patch, the row of ``model.W`` with the largest absolute inner product
    with the patch is selected. The selected rows are then rearranged into an image by
    ``reverse_transform``, without any intensity shift, so raw weight values are returned.

    Args:
        model: object exposing ``W`` of shape ``[m, d]`` (e.g. ``CNN_diff``).
        images: patches of shape ``[num_patches, d]`` or ``[batch, num_patches, d]``.

    Returns:
        NumPy array laid out by ``reverse_transform`` (``[1, 27, 27]`` for a single
        ``[3, 243]`` input).
    """
    with torch.no_grad():
        # Dividing by the patch norm scales every score of a patch equally, so it does
        # not change which neuron is selected.
        scores = torch.einsum('...nd,md->...nm', images, model.W) / images.norm(dim=-1, keepdim=True)
        best = scores.abs().max(dim=-1).indices
        selected = model.W[best]

        # mean=0, std=1: only rearrange the strips, keeping the raw weight values.
        return reverse_transform(selected, mean=0.0, std=1.0).cpu().numpy()

    # plt.imshow(original_image[0], cmap='gray')
    # plt.show()
    #
    # plt.imshow(denorm_image[0], cmap='GnBu')
    # plt.colorbar()
    # plt.show()


def _train_diffusion(model, optimizer, train_loader, epochs):
    """Fit ``model`` to predict the noise ``eps`` in ``alpha * x + beta * eps`` (MSE loss).

    Writes ``diff_model_checkpoint.pth`` (weights, optimizer state, per-epoch loss
    history) after every epoch. Reads the module-level ``device``, ``alpha`` and ``beta``.
    """
    losses = []
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, _ in train_loader:
            images = images.to(device)

            eps = torch.randn_like(images)
            noisy_images = alpha * images + beta * eps

            optimizer.zero_grad()
            loss = F.mse_loss(model(noisy_images), eps)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        loss_log = running_loss / len(train_loader)
        losses.append(loss_log)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': losses,
        }
        torch.save(checkpoint, 'diff_model_checkpoint.pth')

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss_log:.4f}")


def _train_classifier(model, optimizer, train_loader, test_loader, epochs):
    """Fit the +/-1 classifier with the logistic loss, tracking train/test accuracy.

    Writes ``class_model_checkpoint.pth`` (weights, optimizer state, loss and
    accuracy histories) after every epoch and stops early once the mean epoch loss
    drops below 1e-5. Reads the module-level ``device``.
    """
    losses = []
    test_accs = []
    train_accs = []
    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # convert 0, 1 to -1, 1
            sample_y = labels * 2 - 1

            # Clear the previous step's gradients; without this SGD applies the running
            # sum of all past gradients.
            # NOTE(review): earlier runs omitted this call, so lr / epoch settings tuned
            # under that behaviour may need re-tuning.
            optimizer.zero_grad()
            f_pred = model(images)
            # Logistic loss log(1 + exp(-y * f)); softplus evaluates it without the
            # exp overflow (inf / nan gradients) of the naive formula.
            loss = F.softplus(-f_pred * sample_y).mean()

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        loss_log = running_loss / len(train_loader)
        losses.append(loss_log)

        train_acc = calculate_accuracy(train_loader, model)
        test_acc = calculate_accuracy(test_loader, model)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': losses,
            'test_acc': test_accs,
            'train_acc': train_accs
        }
        torch.save(checkpoint, 'class_model_checkpoint.pth')

        print(
            f"Epoch [{epoch + 1}/{epochs}], Loss: {loss_log:.4f}, Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}")

        if loss_log < 1e-5:
            break


def train():
    """Train the model selected by ``--model`` with plain SGD, checkpointing every epoch.

    * ``'cnn'``: ``CNN_class``, lr 0.001, up to 500 epochs, logistic loss.
    * ``'diff'``: ``CNN_diff``, lr 0.5, 5000 epochs, noise-prediction MSE loss.

    Reads the module-level globals set under ``__main__``: ``args``,
    ``train_dataset``, ``test_dataset``, ``batchsize``, ``d``, ``m`` and ``device``
    (plus ``alpha`` / ``beta`` for diffusion).

    Raises:
        NotImplementedError: for any other ``args.model``.
    """
    train_loader = DataLoader(dataset=train_dataset, batch_size=batchsize, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batchsize, shuffle=False)

    if args.model == 'cnn':
        model = CNN_class(d=d, m=m)
        lr = 0.001
        epochs = 500
    elif args.model == 'diff':
        model = CNN_diff(d=d, m=m)
        lr = 0.5
        epochs = 5000
    else:
        raise NotImplementedError

    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    if args.model == 'diff':
        _train_diffusion(model, optimizer, train_loader, epochs)
    else:
        _train_classifier(model, optimizer, train_loader, test_loader, epochs)






def parse_args(args):
    """Parse the command line for this script.

    Args:
        args: list of argument strings, or ``None`` to let argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` whose ``model`` is ``'cnn'`` (default) or ``'diff'``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, default='cnn', choices=['cnn', 'diff'])
    return cli.parse_args(args)



# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    # Seed every RNG in use (Python, NumPy, torch CPU/CUDA) and force deterministic
    # cuDNN kernels so runs are reproducible.
    seed = 100
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Parse sys.argv; --model selects which network train() fits ('cnn' or 'diff').
    args = parse_args(None)

    # NOTE: the names assigned below are module-level globals that train() and
    # visualize_all() read (args, d, m, alpha, beta, batchsize, device, datasets).
    d = 243  # dimension for each patch (one 27x9 strip, flattened)
    m = 100  # Number of weight vectors (neurons) per weight matrix
    time = 0.2 # diffusion time (NOTE(review): shadows the stdlib module name `time`)
    # Learning rate and epoch count are chosen per model inside train().
    # lr = 0.01
    train_sample_per_class = 50
    batchsize = train_sample_per_class*2  # equals the filtered training-set size: one batch per epoch

    # Noising coefficients used by train(): noisy = alpha * x + beta * eps, with
    # alpha**2 + beta**2 == 1.
    alpha = np.exp(-time)
    beta = np.sqrt(1 - np.exp(-2 * time))
    print(alpha, beta)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Custom transform: resize to 27x27, normalize, then split into vertical patches
    transform = NormalizeAndSegmentVerticallyTransform()

    # Load the MNIST dataset with the custom transform
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    # Get the indices for classes 0 and 1 in training and test sets: the first
    # train_sample_per_class (train) / 300 (test) examples of each class, class 0 first.
    train_indices = filter_classes(train_dataset, samples_per_class=train_sample_per_class)
    test_indices = filter_classes(test_dataset, samples_per_class=300)

    # Create subsets of the dataset
    train_dataset = Subset(train_dataset, train_indices)
    test_dataset = Subset(test_dataset, test_indices)

    # Materialize the subsets as stacked tensors.
    # NOTE(review): these four tensors are not used by the active code below (only by
    # commented-out SHAP code); building them applies the transform to every sample.
    train_images = torch.stack([train_dataset[i][0] for i in range(len(train_dataset))])
    train_labels = torch.tensor([train_dataset[i][1] for i in range(len(train_dataset))], dtype=torch.long)
    test_images = torch.stack([test_dataset[i][0] for i in range(len(test_dataset))])
    test_labels = torch.tensor([test_dataset[i][1] for i in range(len(test_dataset))], dtype=torch.long)


    # training
    # Disabled. visualize_all() below loads both 'class_model_checkpoint.pth' and
    # 'diff_model_checkpoint.pth', so train() must have been run with --model cnn
    # and with --model diff beforehand.
    # train()

    # visualize
    # Positions in the test subset. With 300 examples per class, 0-299 are digit 0 and
    # 300-599 are digit 1 (assuming each class has >= 300 test examples), so this picks
    # three zeros and three ones.
    indices = [0,10,37,305,370,543]

    # visualize_diffusion(test_dataset, idx=543)
    # visualize_classification(test_dataset, idx=543)

    visualize_all(test_dataset, indices)

    #
    #
    # images, labels = test_dataset[1]
    # print(images.shape, labels)
    #
    # denorm_image = reverse_transform(images)
    #
    # # Plot the denormalized image
    # plt.imshow(denorm_image[0].numpy(), cmap='gray')
    # plt.show()



    # # Create DataLoaders for the train and test datasets










